#ifndef XG_OPENSSL_SSLSOCKET_CPP
#define XG_OPENSSL_SSLSOCKET_CPP
/////////////////////////////////////////////////////////////////
#include "../SSLSocket.h"

#ifdef _MSC_VER
#pragma comment(lib, "libeay32.lib")
#pragma comment(lib, "ssleay32.lib")
#endif

/*
 * One-time OpenSSL process setup.
 *
 * The constructor initialises the SSL library, its error strings and the
 * algorithm tables, installs OpenSSL's thread-id and locking callbacks, and
 * finally creates and initialises the process-wide SSLContext through
 * GetGlobalContext(true). The shared instance is obtained via Instance().
 */
struct SSLContextSetup
{
	SSLContextSetup()
	{
		SSL_library_init();
		SSL_load_error_strings();
		OpenSSL_add_all_algorithms();

		// Thread-id callback: the u_long API before OpenSSL 1.0.0,
		// the CRYPTO_THREADID API from 1.0.0 on.
#if OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_1_0_0
		CRYPTO_set_id_callback(GetThreadId);
#else
		CRYPTO_THREADID_set_callback(GetThreadIdEx);
#endif
		// Route OpenSSL's lock/unlock requests to GetThreadLockFunc.
		CRYPTO_set_locking_callback(GetThreadLockFunc);

		// Eagerly create and init() the shared global context.
		GetGlobalContext(true);
	}
	// Returns the calling thread's id (GetCurrentThreadId()) as u_long.
	static u_long GetThreadId()
	{
		return (u_long)(GetCurrentThreadId());
	}
	// Returns the single SSLContextSetup instance. Creation is delegated to
	// XG_DEFINE_GLOBAL_VARIABLE (presumably a lazily created process-wide
	// global — confirm against the macro definition).
	static SSLContextSetup* Instance()
	{
		XG_DEFINE_GLOBAL_VARIABLE(SSLContextSetup)
	}
#if OPENSSL_VERSION_NUMBER >= OPENSSL_VERSION_1_0_0
	// CRYPTO_THREADID callback: stores the numeric id from GetThreadId().
	static void GetThreadIdEx(CRYPTO_THREADID* id)
	{
		CRYPTO_THREADID_set_numeric(id, GetThreadId());
	}
#endif
	// Returns the process-wide SSLContext registered under the key
	// "SSLSOCKET_GLOBALCONTEXT".
	// `inited` is forwarded to Process::GetGlobalVariable (presumably meaning
	// "create if missing" — confirm). When it is true, a successfully obtained
	// context is init()'ed and a failed lookup terminates via
	// ErrorExit(XG_SYSERR). When it is false, a failed lookup returns NULL.
	// Once the cached pointer is non-NULL, later calls return it without
	// another lookup.
	// NOTE(review): init()'s return value is ignored, and the cached static is
	// read and written without any locking.
	static SSLContext* GetGlobalContext(bool inited)
	{
		static SSLContext* context = NULL;

		if (context == NULL)
		{
			if (Process::GetGlobalVariable("SSLSOCKET_GLOBALCONTEXT", context, inited))
			{
				if (inited) context->init();
			}
			else
			{
				if (inited) ErrorExit(XG_SYSERR);
			}
		}

		return context;
	}
	// OpenSSL locking callback. On first call it builds one Mutex per OpenSSL
	// lock (CRYPTO_num_locks()). It then locks or unlocks mutex `type`
	// depending on the CRYPTO_LOCK bit of `mode`.
	// `file` and `line` are unused. `type` is not range-checked.
	static void GetThreadLockFunc(int mode, int type, const char* file, int line)
	{
		static vector<Mutex> vec(CRYPTO_num_locks());

		if (mode & CRYPTO_LOCK)
		{
			vec[type].lock();
		}
		else
		{
			vec[type].unlock();
		}
	}
};

// Call SSLContextSetup::Instance() during static initialisation of this
// translation unit, so the OpenSSL setup in the SSLContextSetup constructor
// runs before main() (assumes Instance() constructs the object on first call
// — confirm against XG_DEFINE_GLOBAL_VARIABLE).
static SSLContextSetup* setup = SSLContextSetup::Instance();

SSLContext::~SSLContext()
{
	// Release the owned SSL_CTX; destroy() is a no-op when there is none.
	this->destroy();
}
void SSLContext::destroy()
{
	// Nothing to free if init() never produced a context or destroy() already ran.
	if (ctx == NULL) return;

	SSL_CTX_free(ctx);
	ctx = NULL;
}
bool SSLContext::setMode(long mode)
{
	// SSL_CTX_set_mode yields the resulting mode bitmask. A positive value is
	// treated as success.
	const long current = SSL_CTX_set_mode(ctx, mode);

	return current > 0;
}
SSL_CONTEXT SSLContext::getHandle() const
{
	// Raw OpenSSL context; NULL after destroy() or a failed init().
	return this->ctx;
}
bool SSLContext::init(int version)
{
	// Re-initialising discards any context created by an earlier call.
	destroy();

#if OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_1_0_0
	// Pre-1.0.0 OpenSSL: honour an explicit SSLv2 / SSLv3 request, otherwise
	// use the version-flexible SSLv23 method.
	if (version == 2) ctx = SSL_CTX_new(SSLv2_method());
	else if (version == 3) ctx = SSL_CTX_new(SSLv3_method());
	else ctx = SSL_CTX_new(SSLv23_method());
#else
	// OpenSSL 1.0.0+: `version` is ignored and SSLv23 is always used.
	ctx = SSL_CTX_new(SSLv23_method());
#endif

	CHECK_FALSE_RETURN(ctx);

	// Allow partial writes and retries from a relocated buffer.
	// The result of setMode is not checked.
	setMode(SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);

	return true;
}
bool SSLContext::setClientVerify(bool checked)
{
	// Verify the peer certificate only when `checked` is set. No custom
	// verify callback is installed, and the call always reports success.
	// NOTE(review): on a server, SSL_VERIFY_PEER alone does not reject a client
	// that sends no certificate (that needs SSL_VERIFY_FAIL_IF_NO_PEER_CERT).
	const int verifyMode = checked ? SSL_VERIFY_PEER : SSL_VERIFY_NONE;

	SSL_CTX_set_verify(ctx, verifyMode, NULL);

	return true;
}
bool SSLContext::setCipherList(const string& str)
{
	// A positive result means at least one cipher from `str` was accepted.
	const int selected = SSL_CTX_set_cipher_list(ctx, str.c_str());

	return selected > 0;
}
bool SSLContext::setClientCertificate(const string& certfile)
{
	// Load `certfile` as the trusted CA file used when verifying peer
	// certificates. No CA directory is supplied.
	const int loaded = SSL_CTX_load_verify_locations(ctx, certfile.c_str(), NULL);

	return loaded > 0;
}
bool SSLContext::setCertificate(const string& certfile, int type)
{
	// Load this endpoint's own certificate; `type` is the file encoding
	// passed straight to OpenSSL.
	const int loaded = SSL_CTX_use_certificate_file(ctx, certfile.c_str(), type);

	return loaded > 0;
}
bool SSLContext::setCertPrivateKey(const string& keyfile, int type)
{
	// Load the private key first, then confirm it is consistent with the
	// certificate already loaded into the context.
	CHECK_FALSE_RETURN(SSL_CTX_use_PrivateKey_file(ctx, keyfile.c_str(), type) > 0);

	return SSL_CTX_check_private_key(ctx) != 0;
}
SSLContext* SSLContext::GetGlobalContext()
{
	// Lookup only (no creation/initialisation): may yield NULL if the shared
	// context has not been set up.
	SSLContext* shared = SSLContextSetup::GetGlobalContext(false);

	return shared;
}
int SSLContext::SelectALPN(SSL* ssl, const u_char** out, u_char* outlen, const u_char* in, u_int inlen, void* arg)
{
#if OPENSSL_VERSION_NUMBER >= OPENSSL_VERSION_1_0_2
	// Server preference list in ALPN wire format (length-prefixed names):
	// "h2" first, then "http/1.1".
	static u_char protocols[] = {2, 'h', '2', 8, 'h', 't', 't', 'p', '/', '1', '.', '1'};

	const int outcome = SSL_select_next_proto(const_cast<u_char**>(out), outlen, protocols, sizeof(protocols), in, inlen);

	if (outcome == OPENSSL_NPN_NEGOTIATED) return SSL_TLSEXT_ERR_OK;
#endif

	// No common protocol, or ALPN not compiled in: select nothing.
	return SSL_TLSEXT_ERR_NOACK;
}

SSLSocket::~SSLSocket()
{
	// Free the TLS session and close the underlying socket.
	this->close();
}
void SSLSocket::close()
{
	// Tear down in order: TLS session first (best-effort close_notify, result
	// ignored), then the context reference, then the raw socket.
	// NOTE(review): a failed SSL_shutdown can leave entries on this thread's
	// OpenSSL error queue, which SSL_get_error() consults. Consider
	// ERR_clear_error() here.
	if (ssl != NULL)
	{
		SSL_shutdown(ssl);
		SSL_free(ssl);
		ssl = NULL;
	}

	// The context is shared/borrowed, so only the reference is dropped.
	ctx = NULL;

	Socket::close();
}
int SSLSocket::accept()
{
	int res = SSL_accept(ssl);
	
	if (res > 0) return res;

	res = SSL_get_error(ssl, res);

	if (res == SSL_ERROR_WANT_READ) return XG_TIMEOUT;

	if (res == SSL_ERROR_WANT_WRITE) return XG_TIMEOUT;

	return XG_NETERR;
}
SSL_SOCKET SSLSocket::getHandle() const
{
	// Underlying OpenSSL session; NULL after close().
	return this->ssl;
}
bool SSLSocket::init(SOCKET sock)
{
	// Convenience overload: bind to the process-wide shared context.
	SSLContext* shared = SSLContext::GetGlobalContext();

	return init(sock, shared);
}
/*
 * Create a new SSL session from `ctx` and attach it to `sock`.
 *
 * Any previous session/socket is closed first. If `sock` is already closed,
 * the session is created but left without a descriptor and true is returned.
 * On failure this object does not record `sock`, so the caller still owns it
 * (connect() closes it in that case).
 */
bool SSLSocket::init(SOCKET sock, SSLContext* ctx)
{
	this->close();
	this->ctx = ctx;

	CHECK_FALSE_RETURN(ctx && ctx->getHandle());
	CHECK_FALSE_RETURN(ssl = SSL_new(ctx->getHandle()));

	if (IsSocketClosed(sock)) return true;

	// SSL_set_fd returns 1 on success and 0 on failure, so only a strictly
	// positive result means the descriptor is attached.
	CHECK_FALSE_RETURN(SSL_set_fd(ssl, sock) > 0);

	this->sock = sock;

	return true;
}
/*
 * Open a TCP connection to ip:port (waiting up to `timeout` for it) and
 * perform the client-side TLS handshake using the global context.
 *
 * Returns true only when the handshake completed. On any failure the socket
 * is closed and false is returned.
 */
bool SSLSocket::connect(const string& ip, int port, int timeout)
{
	SOCKET sock = SocketConnectTimeout(ip.c_str(), port, timeout);

	if (IsSocketClosed(sock)) return false;

	if (init(sock, SSLContext::GetGlobalContext()))
	{
		SocketSetSendTimeout(sock, 1000);
		SocketSetRecvTimeout(sock, 1000);

		// SSL_connect returns 1 on success, 0 when the handshake failed but was
		// shut down cleanly, and < 0 on a fatal error. Only 1 is a success.
		if (SSL_connect(ssl) > 0) return true;

		this->close();
	}
	else
	{
		// init() did not take ownership of the descriptor; release it here.
		SocketClose(sock);
	}

	return false;
}
int SSLSocket::peek(void* data, int size)
{
	if (ssl == NULL) return XG_NETCLOSE;

	const int peeked = SSL_peek(ssl, data, size);

	if (peeked > 0) return peeked;

	// Translate the OpenSSL failure reason into this library's codes:
	// "try again" -> 0, orderly TLS close -> XG_NETCLOSE, else XG_NETERR.
	switch (SSL_get_error(ssl, peeked))
	{
	case SSL_ERROR_WANT_READ:
	case SSL_ERROR_WANT_WRITE:
		return 0;
	case SSL_ERROR_ZERO_RETURN:
		return XG_NETCLOSE;
	default:
		return XG_NETERR;
	}
}
int SSLSocket::read(void* data, int size)
{
	// Default read keeps reading until `size` bytes arrive (completed mode).
	const bool completed = true;

	return read(data, size, completed);
}
int SSLSocket::write(const void* data, int size)
{
	// Default write keeps writing until all `size` bytes are sent (completed mode).
	const bool completed = true;

	return write(data, size, completed);
}
int SSLSocket::tryCheck(int timeout, bool ckrd)
{
	if (ssl == NULL) return XG_NETCLOSE;

	int res = SSL_pending(ssl);

	if (res > 0) return res;

	if (res < 0) return XG_NETERR;

	if (timeout <= 0) return XG_TIMEOUT;

	Sleep(timeout);

	res = SSL_pending(ssl);

	if (res > 0) return res;

	if (res < 0) return XG_NETERR;

	return XG_TIMEOUT;
}
/*
 * Read from the TLS session.
 *
 * completed == false: a single SSL_read. Returns the byte count, 0 when it
 * should be retried (WANT_* or a socket timeout), XG_NETCLOSE on an orderly
 * TLS close, otherwise XG_NETERR.
 *
 * completed == true: loop until `size` bytes are read. A chunk larger than
 * SOCKET_TIMEOUT_LIMITSIZE resets the retry budget. Smaller chunks and WANT_*
 * results consume it, and XG_TIMEOUT is returned once it exceeds
 * SOCKET_TIMEOUT_REDOTIMES.
 */
int SSLSocket::read(void* data, int size, bool completed)
{
	if (ssl == NULL) return XG_NETCLOSE;

	if (!completed)
	{
		const int got = SSL_read(ssl, data, size);

		if (got > 0) return got;

		const int reason = SSL_get_error(ssl, got);

		if (reason == SSL_ERROR_WANT_READ || reason == SSL_ERROR_WANT_WRITE) return 0;

		if (reason == SSL_ERROR_ZERO_RETURN) return XG_NETCLOSE;

		return IsSocketTimeout() ? 0 : XG_NETERR;
	}

	char* cursor = static_cast<char*>(data);
	int total = 0;
	int retries = 0;

	// Result deliberately ignored; presumably used to wait for/prime incoming
	// data before the read loop — confirm intent.
	SSL_peek(ssl, data, size);

	while (total < size)
	{
		const int got = SSL_read(ssl, cursor + total, size - total);

		if (got <= 0)
		{
			const int reason = SSL_get_error(ssl, got);

			if (reason != SSL_ERROR_WANT_READ && reason != SSL_ERROR_WANT_WRITE)
			{
				return reason == SSL_ERROR_ZERO_RETURN ? XG_NETCLOSE : XG_NETERR;
			}

			if (++retries > SOCKET_TIMEOUT_REDOTIMES) return XG_TIMEOUT;

			continue;
		}

		if (got > SOCKET_TIMEOUT_LIMITSIZE)
		{
			retries = 0;
		}
		else if (++retries > SOCKET_TIMEOUT_REDOTIMES)
		{
			return XG_TIMEOUT;
		}

		total += got;
	}

	return total;
}
/*
 * Write to the TLS session, at most SOCKECT_FRAMESIZE bytes per SSL_write.
 *
 * completed == false: a single SSL_write. Returns the byte count, 0 when it
 * should be retried (WANT_*), XG_NETCLOSE on an orderly TLS close, otherwise
 * XG_NETERR.
 *
 * completed == true: loop until all `size` bytes are written. A chunk larger
 * than SOCKET_TIMEOUT_LIMITSIZE resets the retry budget. Smaller chunks and
 * WANT_* results consume it, and XG_TIMEOUT is returned once it exceeds
 * SOCKET_TIMEOUT_REDOTIMES.
 */
int SSLSocket::write(const void* data, int size, bool completed)
{
	if (ssl == NULL) return XG_NETCLOSE;

	if (!completed)
	{
		const int sent = SSL_write(ssl, data, min(SOCKECT_FRAMESIZE, size));

		if (sent > 0) return sent;

		const int reason = SSL_get_error(ssl, sent);

		if (reason == SSL_ERROR_WANT_READ || reason == SSL_ERROR_WANT_WRITE) return 0;

		return reason == SSL_ERROR_ZERO_RETURN ? XG_NETCLOSE : XG_NETERR;
	}

	const char* cursor = static_cast<const char*>(data);
	int total = 0;
	int retries = 0;

	while (total < size)
	{
		const int sent = SSL_write(ssl, cursor + total, min(SOCKECT_FRAMESIZE, size - total));

		if (sent <= 0)
		{
			const int reason = SSL_get_error(ssl, sent);

			if (reason != SSL_ERROR_WANT_READ && reason != SSL_ERROR_WANT_WRITE)
			{
				return reason == SSL_ERROR_ZERO_RETURN ? XG_NETCLOSE : XG_NETERR;
			}

			if (++retries > SOCKET_TIMEOUT_REDOTIMES) return XG_TIMEOUT;

			continue;
		}

		if (sent > SOCKET_TIMEOUT_LIMITSIZE)
		{
			retries = 0;
		}
		else if (++retries > SOCKET_TIMEOUT_REDOTIMES)
		{
			return XG_TIMEOUT;
		}

		total += sent;
	}

	return total;
}

sp<Socket> SSLSocketPool::Connect(const string& host, int port)
{
	// Reuse the pool already registered for this endpoint, if any.
	if (sp<SocketPool> existing = Get(host, port)) return existing->get();

	// First request for host:port: register a fresh SSL pool and draw from it.
	return Set(host, port, newsp<SSLSocketPool>(host, port))->get();
}
/*
 * Pool of SSL client sockets for host:port.
 *
 * Installs a `creator` that opens a new SSLSocket to this pool's endpoint.
 * It returns the socket on success, or an empty pointer if connect() fails.
 */
SSLSocketPool::SSLSocketPool(const string& host, int port) : SocketPool(host, port)
{
	// `creator` is stored on the object and presumably invoked after this
	// constructor has returned. Capture only `this`: a blanket [&] would also
	// make the constructor's reference parameters `host`/`port` capturable,
	// and any unqualified use of them inside the lambda would dangle.
	creator = [this](){
		sp<SSLSocket> sock = newsp<SSLSocket>();

		if (sock->connect(this->host, this->port)) return sock;

		return sock = NULL;
	};
}
/////////////////////////////////////////////////////////////////
#endif